/**
 * @file div_custom.cpp
 *
 * Copyright (C) 2025. Huawei Technologies Co., Ltd. All rights reserved.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
 */
#include "div_custom_tiling.h"
#include "kernel_operator.h"

// Number of tensors allocated per TQue; each tile is also split into this many
// chunks so the queues can double-buffer.
constexpr int32_t BUFFER_NUM{2};

// Element-type codes compared against DivCustomTilingData::dataType by the kernel entry.
constexpr uint32_t DIV_BFLOAT16{0};
constexpr uint32_t DIV_FLOAT16{1};
constexpr uint32_t DIV_FLOAT32{2};
constexpr uint32_t DIV_INT8{3};
constexpr uint32_t DIV_INT16{4};
constexpr uint32_t DIV_INT32{5};

// Number of chunks subtracted from the final loop index when computing the
// shifted start offset of the last chunk of a core's slice.
constexpr uint32_t LAST_TWO_TILE{2};

// Primary template; it is defined later in this file, together with an explicit int8_t specialization.
template <typename dataType> class KernelDiv;


/**
 * int8_t specialization of KernelDiv: z = x / y, evaluated in half precision.
 *
 * AscendC::Div is not called on int8_t data here. Each chunk of x and y is
 * widened to half, divided, and then narrowed back to int8_t.
 */
template <> class KernelDiv <int8_t> {
public:
    __aicore__ inline KernelDiv() {}

    /**
     * Bind this core's slice of x/y/z, then size the queues and the half scratch buffers.
     * @param x      global address of the dividend
     * @param y      global address of the divisor
     * @param z      global address of the quotient
     * @param tiling Host-computed tiling.
     *               - When isEvenCore is set, every core takes blockLength elements.
     *               - Otherwise the first formerNum cores take formerLength elements each.
     *               - The remaining cores then take tailLength elements each.
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, DivCustomTilingData tiling)
    {
        const auto blockIdx = AscendC::GetBlockIdx();
        // Element offset of this core's slice. It is computed in 64 bits so that the
        // formerLength * formerNum product cannot wrap in 32-bit arithmetic.
        uint64_t gmOffset = 0;
        if (tiling.isEvenCore) {
            this->blockLength = tiling.blockLength;
            this->tileNum = tiling.tileNum;
            this->tileLength = tiling.tileLength / BUFFER_NUM;
            this->lastTileLength = tiling.lastTileLength;
            gmOffset = static_cast<uint64_t>(tiling.blockLength) * blockIdx;
        } else if (blockIdx < tiling.formerNum) {
            // Assign blockLength on every path; previously it was left uninitialised here.
            this->blockLength = tiling.formerLength;
            this->tileNum = tiling.formerTileNum;
            this->tileLength = tiling.formerTileLength / BUFFER_NUM;
            this->lastTileLength = tiling.formerLastTileLength;
            gmOffset = static_cast<uint64_t>(tiling.formerLength) * blockIdx;
        } else {
            this->blockLength = tiling.tailLength;
            this->tileNum = tiling.tailTileNum;
            this->tileLength = tiling.tailTileLength / BUFFER_NUM;
            this->lastTileLength = tiling.tailLastTileLength;
            // Skip all former cores, then the tail cores that precede this one.
            gmOffset = static_cast<uint64_t>(tiling.formerLength) * tiling.formerNum +
                static_cast<uint64_t>(tiling.tailLength) * (blockIdx - tiling.formerNum);
        }
        xGm.SetGlobalBuffer((__gm__ int8_t *)x + gmOffset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ int8_t *)y + gmOffset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ int8_t *)z + gmOffset, this->blockLength);

        // Each queue buffer holds one chunk of tileLength elements.
        pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(int8_t));
        pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(int8_t));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(int8_t));

        // Scratch space for the half-precision copies of one chunk of x and y.
        pipe.InitBuffer(tmpBuf0, this->tileLength * sizeof(half));
        pipe.InitBuffer(tmpBuf1, this->tileLength * sizeof(half));
    }

    /// Run copy-in / compute / copy-out over all tileNum * BUFFER_NUM chunks of this core's slice.
    __aicore__ inline void Process()
    {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }

private:
    /**
     * Element offset, within this core's slice, of the chunk for loop iteration `progress`.
     *
     * Every chunk is tileLength elements long.
     * - The final chunk starts at (progress - LAST_TWO_TILE) * tileLength + lastTileLength
     *   instead of progress * tileLength.
     * - Presumably this shift makes the final chunk end exactly at the end of a slice made of
     *   (tileNum - 1) full tiles plus a lastTileLength tail. TODO confirm against the host tiling.
     * - The arithmetic is done in uint32_t. When tileNum == 1 the subtraction wraps, and the
     *   result depends on modular arithmetic.
     */
    __aicore__ inline uint32_t ChunkOffset(int32_t progress) const
    {
        if (progress == (this->tileNum * BUFFER_NUM - 1)) {
            return (progress - LAST_TWO_TILE) * this->tileLength + this->lastTileLength;
        }
        return progress * this->tileLength;
    }

    /// Copy one chunk of x and y from global memory into the input queues.
    __aicore__ inline void CopyIn(int32_t progress)
    {
        AscendC::LocalTensor<int8_t> xLocal = inQueueX.AllocTensor<int8_t>();
        AscendC::LocalTensor<int8_t> yLocal = inQueueY.AllocTensor<int8_t>();
        const uint32_t offset = ChunkOffset(progress);
        AscendC::DataCopy(xLocal, xGm[offset], this->tileLength);
        AscendC::DataCopy(yLocal, yGm[offset], this->tileLength);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }

    /// Widen x and y to half, divide, narrow the quotient back to int8_t, and enqueue it.
    __aicore__ inline void Compute(int32_t progress)
    {
        AscendC::LocalTensor<int8_t> xLocal = inQueueX.DeQue<int8_t>();
        AscendC::LocalTensor<int8_t> yLocal = inQueueY.DeQue<int8_t>();
        AscendC::LocalTensor<int8_t> zLocal = outQueueZ.AllocTensor<int8_t>();

        AscendC::LocalTensor<half> tmpTensor0 = tmpBuf0.Get<half>();
        AscendC::LocalTensor<half> tmpTensor1 = tmpBuf1.Get<half>();

        AscendC::Cast(tmpTensor0, xLocal, AscendC::RoundMode::CAST_NONE, this->tileLength);
        AscendC::Cast(tmpTensor1, yLocal, AscendC::RoundMode::CAST_NONE, this->tileLength);

        // The quotient overwrites the widened x in place.
        AscendC::Div(tmpTensor0, tmpTensor0, tmpTensor1, this->tileLength);
        // NOTE(review): CAST_NONE on a lossy half->int8_t conversion selects a rounding mode
        // (see the Ascend C Cast docs). It is not necessarily C-style truncation. Confirm the
        // intended integer-division semantics (CAST_TRUNC / CAST_FLOOR?).
        AscendC::Cast(zLocal, tmpTensor0, AscendC::RoundMode::CAST_NONE, this->tileLength);

        outQueueZ.EnQue<int8_t>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }

    /// Write one computed chunk of z back to global memory.
    __aicore__ inline void CopyOut(int32_t progress)
    {
        AscendC::LocalTensor<int8_t> zLocal = outQueueZ.DeQue<int8_t>();
        AscendC::DataCopy(zGm[ChunkOffset(progress)], zLocal, this->tileLength);
        outQueueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf0;
    AscendC::TBuf<AscendC::TPosition::VECCALC> tmpBuf1;

    AscendC::GlobalTensor<int8_t> xGm;
    AscendC::GlobalTensor<int8_t> yGm;
    AscendC::GlobalTensor<int8_t> zGm;

    uint32_t blockLength;     // elements in this core's slice
    uint32_t tileNum;         // tiles in this core's slice
    uint32_t tileLength;      // elements per chunk (tiling tile length / BUFFER_NUM)
    uint32_t lastTileLength;  // length of the final tile, as supplied by the tiling
};

/**
 * Generic element-wise division kernel: z = x / y, for element types on which
 * AscendC::Div is called directly.
 *
 * Work distribution:
 * - Each AI core handles one contiguous slice of the global tensors.
 * - The slice is processed in tileNum tiles.
 * - Each tile is split into BUFFER_NUM chunks that flow through the VECIN/VECOUT queues.
 */
template <typename dataType> class KernelDiv {
public:
    __aicore__ inline KernelDiv() {}

    /**
     * Bind this core's slice of x/y/z and size the queues.
     * @param x      global address of the dividend
     * @param y      global address of the divisor
     * @param z      global address of the quotient
     * @param tiling Host-computed tiling.
     *               - When isEvenCore is set, every core takes blockLength elements.
     *               - Otherwise the first formerNum cores take formerLength elements each.
     *               - The remaining cores then take tailLength elements each.
     */
    __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z, DivCustomTilingData tiling)
    {
        const auto blockIdx = AscendC::GetBlockIdx();
        // Element offset of this core's slice. It is computed in 64 bits so that the
        // formerLength * formerNum product cannot wrap in 32-bit arithmetic.
        uint64_t gmOffset = 0;
        if (tiling.isEvenCore) {
            this->blockLength = tiling.blockLength;
            this->tileNum = tiling.tileNum;
            this->tileLength = tiling.tileLength / BUFFER_NUM;
            this->lastTileLength = tiling.lastTileLength;
            gmOffset = static_cast<uint64_t>(tiling.blockLength) * blockIdx;
        } else if (blockIdx < tiling.formerNum) {
            // Assign blockLength on every path; previously it was left uninitialised here.
            this->blockLength = tiling.formerLength;
            this->tileNum = tiling.formerTileNum;
            this->tileLength = tiling.formerTileLength / BUFFER_NUM;
            this->lastTileLength = tiling.formerLastTileLength;
            gmOffset = static_cast<uint64_t>(tiling.formerLength) * blockIdx;
        } else {
            this->blockLength = tiling.tailLength;
            this->tileNum = tiling.tailTileNum;
            this->tileLength = tiling.tailTileLength / BUFFER_NUM;
            this->lastTileLength = tiling.tailLastTileLength;
            // Skip all former cores, then the tail cores that precede this one.
            gmOffset = static_cast<uint64_t>(tiling.formerLength) * tiling.formerNum +
                static_cast<uint64_t>(tiling.tailLength) * (blockIdx - tiling.formerNum);
        }
        xGm.SetGlobalBuffer((__gm__ dataType *)x + gmOffset, this->blockLength);
        yGm.SetGlobalBuffer((__gm__ dataType *)y + gmOffset, this->blockLength);
        zGm.SetGlobalBuffer((__gm__ dataType *)z + gmOffset, this->blockLength);

        // Each queue buffer holds one chunk of tileLength elements.
        pipe.InitBuffer(inQueueX, BUFFER_NUM, this->tileLength * sizeof(dataType));
        pipe.InitBuffer(inQueueY, BUFFER_NUM, this->tileLength * sizeof(dataType));
        pipe.InitBuffer(outQueueZ, BUFFER_NUM, this->tileLength * sizeof(dataType));
    }

    /// Run copy-in / compute / copy-out over all tileNum * BUFFER_NUM chunks of this core's slice.
    __aicore__ inline void Process()
    {
        int32_t loopCount = this->tileNum * BUFFER_NUM;
        for (int32_t i = 0; i < loopCount; i++) {
            CopyIn(i);
            Compute(i);
            CopyOut(i);
        }
    }

private:
    /**
     * Element offset, within this core's slice, of the chunk for loop iteration `progress`.
     *
     * Every chunk is tileLength elements long.
     * - The final chunk starts at (progress - LAST_TWO_TILE) * tileLength + lastTileLength
     *   instead of progress * tileLength.
     * - Presumably this shift makes the final chunk end exactly at the end of a slice made of
     *   (tileNum - 1) full tiles plus a lastTileLength tail. TODO confirm against the host tiling.
     * - The arithmetic is done in uint32_t. When tileNum == 1 the subtraction wraps, and the
     *   result depends on modular arithmetic.
     */
    __aicore__ inline uint32_t ChunkOffset(int32_t progress) const
    {
        if (progress == (this->tileNum * BUFFER_NUM - 1)) {
            return (progress - LAST_TWO_TILE) * this->tileLength + this->lastTileLength;
        }
        return progress * this->tileLength;
    }

    /// Copy one chunk of x and y from global memory into the input queues.
    __aicore__ inline void CopyIn(int32_t progress)
    {
        AscendC::LocalTensor<dataType> xLocal = inQueueX.AllocTensor<dataType>();
        AscendC::LocalTensor<dataType> yLocal = inQueueY.AllocTensor<dataType>();
        const uint32_t offset = ChunkOffset(progress);
        AscendC::DataCopy(xLocal, xGm[offset], this->tileLength);
        AscendC::DataCopy(yLocal, yGm[offset], this->tileLength);
        inQueueX.EnQue(xLocal);
        inQueueY.EnQue(yLocal);
    }

    /// Divide one chunk element-wise and enqueue the result for copy-out.
    __aicore__ inline void Compute(int32_t progress)
    {
        AscendC::LocalTensor<dataType> xLocal = inQueueX.DeQue<dataType>();
        AscendC::LocalTensor<dataType> yLocal = inQueueY.DeQue<dataType>();
        AscendC::LocalTensor<dataType> zLocal = outQueueZ.AllocTensor<dataType>();

        AscendC::Div(zLocal, xLocal, yLocal, this->tileLength);

        outQueueZ.EnQue<dataType>(zLocal);
        inQueueX.FreeTensor(xLocal);
        inQueueY.FreeTensor(yLocal);
    }

    /// Write one computed chunk of z back to global memory.
    __aicore__ inline void CopyOut(int32_t progress)
    {
        AscendC::LocalTensor<dataType> zLocal = outQueueZ.DeQue<dataType>();
        AscendC::DataCopy(zGm[ChunkOffset(progress)], zLocal, this->tileLength);
        outQueueZ.FreeTensor(zLocal);
    }

private:
    AscendC::TPipe pipe;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueX;
    AscendC::TQue<AscendC::TPosition::VECIN, BUFFER_NUM> inQueueY;
    AscendC::TQue<AscendC::TPosition::VECOUT, BUFFER_NUM> outQueueZ;

    AscendC::GlobalTensor<dataType> xGm;
    AscendC::GlobalTensor<dataType> yGm;
    AscendC::GlobalTensor<dataType> zGm;

    uint32_t blockLength;     // elements in this core's slice
    uint32_t tileNum;         // tiles in this core's slice
    uint32_t tileLength;      // elements per chunk (tiling tile length / BUFFER_NUM)
    uint32_t lastTileLength;  // length of the final tile, as supplied by the tiling
};
        
extern "C" __global__ __aicore__ void div_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z, DivCustomTilingData tiling)
{
    if (tiling.dataType == DIV_FLOAT16) {
        KernelDiv<half> op;
        op.Init(x, y, z, tiling);
        op.Process();
    } else {
        return;
    }
}
